{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transformers installation\n",
    "! pip install transformers datasets\n",
    "# To install from source instead of the last release, comment the command above and uncomment the following one.\n",
    "# ! pip install git+https://github.com/huggingface/transformers.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to fine-tune a model for common downstream tasks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This guide will show you how to fine-tune 🤗 Transformers models for common downstream tasks. You will use the 🤗\n",
    "Datasets library to quickly load and preprocess the datasets, getting them ready for training with PyTorch and\n",
    "TensorFlow.\n",
    "\n",
    "Before you begin, make sure you have the 🤗 Datasets library installed. For more detailed installation instructions,\n",
    "refer to the 🤗 Datasets [installation page](https://huggingface.co/docs/datasets/installation.html). All of the\n",
    "examples in this guide will use 🤗 Datasets to load and preprocess a dataset.\n",
    "\n",
    "```bash\n",
    "pip install datasets\n",
    "```\n",
    "\n",
    "Learn how to fine-tune a model for:\n",
    "\n",
    "- [seq_imdb](#seq_imdb)\n",
    "- [tok_ner](#tok_ner)\n",
    "- [qa_squad](#qa_squad)\n",
    "\n",
    "<a id='seq_imdb'></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sequence classification with IMDb reviews"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sequence classification refers to the task of classifying sequences of text according to a given number of classes. In\n",
    "this example, learn how to fine-tune a model on the [IMDb dataset](https://huggingface.co/datasets/imdb) to determine\n",
    "whether a review is positive or negative.\n",
    "\n",
    "<Tip>\n",
    "\n",
    "For a more in-depth example of how to fine-tune a model for text classification, take a look at the corresponding\n",
    "[PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/text_classification.ipynb)\n",
    "or [TensorFlow notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/text_classification-tf.ipynb).\n",
    "\n",
    "</Tip>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load IMDb dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The 🤗 Datasets library makes it simple to load a dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "imdb = load_dataset(\"imdb\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This loads a `DatasetDict` object which you can index into to view an example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imdb[\"train\"][0]\n",
    "{'label': 1,\n",
    " 'text': 'Bromwell High is a cartoon comedy. It ran at the same time as some other programs about school life, such as \"Teachers\". My 35 years in the teaching profession lead me to believe that Bromwell High\\'s satire is much closer to reality than is \"Teachers\". The scramble to survive financially, the insightful students who can see right through their pathetic teachers\\' pomp, the pettiness of the whole situation, all remind me of the schools I knew and their students. When I saw the episode in which a student repeatedly tried to burn down the school, I immediately recalled ......... at .......... High. A classic line: INSPECTOR: I\\'m here to sack one of your teachers. STUDENT: Welcome to Bromwell High. I expect that many adults of my age think that Bromwell High is far fetched. What a pity that it isn\\'t!'\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocess"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next step is to tokenize the text into a readable format by the model. It is important to load the same tokenizer a\n",
    "model was trained with to ensure appropriately tokenized words. Load the DistilBERT tokenizer with the\n",
    "[AutoTokenizer](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoTokenizer) because we will eventually train a classifier using a pretrained [DistilBERT](https://huggingface.co/distilbert-base-uncased) model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that you have instantiated a tokenizer, create a function that will tokenize the text. You should also truncate\n",
    "longer sequences in the text to be no longer than the model's maximum input length:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_function(examples):\n",
    "    return tokenizer(examples[\"text\"], truncation=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use 🤗 Datasets `map` function to apply the preprocessing function to the entire dataset. You can also set\n",
    "`batched=True` to apply the preprocessing function to multiple elements of the dataset at once for faster\n",
    "preprocessing:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_imdb = imdb.map(preprocess_function, batched=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lastly, pad your text so they are a uniform length. While it is possible to pad your text in the `tokenizer` function\n",
    "by setting `padding=True`, it is more efficient to only pad the text to the length of the longest element in its\n",
    "batch. This is known as **dynamic padding**. You can do this with the `DataCollatorWithPadding` function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorWithPadding\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fine-tune with the Trainer API"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now load your model with the [AutoModelForSequenceClassification](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForSequenceClassification) class along with the number of expected labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-uncased\", num_labels=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point, only three steps remain:\n",
    "\n",
    "1. Define your training hyperparameters in [TrainingArguments](https://huggingface.co/docs/transformers/master/en/main_classes/trainer#transformers.TrainingArguments).\n",
    "2. Pass the training arguments to a [Trainer](https://huggingface.co/docs/transformers/master/en/main_classes/trainer#transformers.Trainer) along with the model, dataset, tokenizer, and data collator.\n",
    "3. Call `Trainer.train()` to fine-tune your model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='./results',\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=5,\n",
    "    weight_decay=0.01,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_imdb[\"train\"],\n",
    "    eval_dataset=tokenized_imdb[\"test\"],\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=data_collator,\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fine-tune with TensorFlow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fine-tuning with TensorFlow is just as easy, with only a few differences.\n",
    "\n",
    "Start by batching the processed examples together with dynamic padding using the [DataCollatorWithPadding](https://huggingface.co/docs/transformers/master/en/main_classes/data_collator#transformers.DataCollatorWithPadding) function.\n",
    "Make sure you set `return_tensors=\"tf\"` to return `tf.Tensor` outputs instead of PyTorch tensors!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorWithPadding\n",
    "data_collator = DataCollatorWithPadding(tokenizer, return_tensors=\"tf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, convert your datasets to the `tf.data.Dataset` format with `to_tf_dataset`. Specify inputs and labels in the\n",
    "`columns` argument:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf_train_dataset = tokenized_imdb[\"train\"].to_tf_dataset(\n",
    "    columns=['attention_mask', 'input_ids', 'label'],\n",
    "    shuffle=True,\n",
    "    batch_size=16,\n",
    "    collate_fn=data_collator,\n",
    ")\n",
    "\n",
    "tf_validation_dataset = tokenized_imdb[\"train\"].to_tf_dataset(\n",
    "    columns=['attention_mask', 'input_ids', 'label'],\n",
    "    shuffle=False,\n",
    "    batch_size=16,\n",
    "    collate_fn=data_collator,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up an optimizer function, learning rate schedule, and some training hyperparameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import create_optimizer\n",
    "import tensorflow as tf\n",
    "\n",
    "batch_size = 16\n",
    "num_epochs = 5\n",
    "batches_per_epoch = len(tokenized_imdb[\"train\"]) // batch_size\n",
    "total_train_steps = int(batches_per_epoch * num_epochs)\n",
    "optimizer, schedule = create_optimizer(\n",
    "    init_lr=2e-5, \n",
    "    num_warmup_steps=0, \n",
    "    num_train_steps=total_train_steps\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load your model with the [TFAutoModelForSequenceClassification](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.TFAutoModelForSequenceClassification) class along with the number of expected labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TFAutoModelForSequenceClassification\n",
    "model = TFAutoModelForSequenceClassification.from_pretrained(\"distilbert-base-uncased\", num_labels=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compile the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "model.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, fine-tune the model by calling `model.fit`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.fit(\n",
    "    tf_train_set,\n",
    "    validation_data=tf_validation_set,\n",
    "    epochs=num_train_epochs,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='tok_ner'></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Token classification with WNUT emerging entities"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Token classification refers to the task of classifying individual tokens in a sentence. One of the most common token\n",
    "classification tasks is Named Entity Recognition (NER). NER attempts to find a label for each entity in a sentence,\n",
    "such as a person, location, or organization. In this example, learn how to fine-tune a model on the [WNUT 17](https://huggingface.co/datasets/wnut_17) dataset to detect new entities.\n",
    "\n",
    "<Tip>\n",
    "\n",
    "For a more in-depth example of how to fine-tune a model for token classification, take a look at the corresponding\n",
    "[PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/token_classification.ipynb)\n",
    "or [TensorFlow notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/token_classification-tf.ipynb).\n",
    "\n",
    "</Tip>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load WNUT 17 dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the WNUT 17 dataset from the 🤗 Datasets library:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "wnut = load_dataset(\"wnut_17\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick look at the dataset shows the labels associated with each word in the sentence:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wnut[\"train\"][0]\n",
    "{'id': '0',\n",
    " 'ner_tags': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 8, 8, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0],\n",
    " 'tokens': ['@paulwalk', 'It', \"'s\", 'the', 'view', 'from', 'where', 'I', \"'m\", 'living', 'for', 'two', 'weeks', '.', 'Empire', 'State', 'Building', '=', 'ESB', '.', 'Pretty', 'bad', 'storm', 'here', 'last', 'evening', '.']\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "View the specific NER tags by:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_list = wnut[\"train\"].features[f\"ner_tags\"].feature.names\n",
    "label_list\n",
    "['O',\n",
    " 'B-corporation',\n",
    " 'I-corporation',\n",
    " 'B-creative-work',\n",
    " 'I-creative-work',\n",
    " 'B-group',\n",
    " 'I-group',\n",
    " 'B-location',\n",
    " 'I-location',\n",
    " 'B-person',\n",
    " 'I-person',\n",
    " 'B-product',\n",
    " 'I-product'\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A letter prefixes each NER tag which can mean:\n",
    "\n",
    "- `B-` indicates the beginning of an entity.\n",
    "- `I-` indicates a token is contained inside the same entity (e.g., the `State` token is a part of an entity like\n",
    "  `Empire State Building`).\n",
    "- `0` indicates the token doesn't correspond to any entity."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocess"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now you need to tokenize the text. Load the DistilBERT tokenizer with an [AutoTokenizer](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoTokenizer):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the input has already been split into words, set `is_split_into_words=True` to tokenize the words into\n",
    "subwords:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_input = tokenizer(example[\"tokens\"], is_split_into_words=True)\n",
    "tokens = tokenizer.convert_ids_to_tokens(tokenized_input[\"input_ids\"])\n",
    "tokens\n",
    "['[CLS]', '@', 'paul', '##walk', 'it', \"'\", 's', 'the', 'view', 'from', 'where', 'i', \"'\", 'm', 'living', 'for', 'two', 'weeks', '.', 'empire', 'state', 'building', '=', 'es', '##b', '.', 'pretty', 'bad', 'storm', 'here', 'last', 'evening', '.', '[SEP]']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The addition of the special tokens `[CLS]` and `[SEP]` and subword tokenization creates a mismatch between the\n",
    "input and labels. Realign the labels and tokens by:\n",
    "\n",
    "1. Mapping all tokens to their corresponding word with the `word_ids` method.\n",
    "2. Assigning the label `-100` to the special tokens `[CLS]` and ``[SEP]``` so the PyTorch loss function ignores\n",
    "   them.\n",
    "3. Only labeling the first token of a given word. Assign `-100` to the other subtokens from the same word.\n",
    "\n",
    "Here is how you can create a function that will realign the labels and tokens:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_and_align_labels(examples):\n",
    "    tokenized_inputs = tokenizer(examples[\"tokens\"], truncation=True, is_split_into_words=True)\n",
    "\n",
    "    labels = []\n",
    "    for i, label in enumerate(examples[f\"ner_tags\"]):\n",
    "        word_ids = tokenized_inputs.word_ids(batch_index=i)  # Map tokens to their respective word.\n",
    "        previous_word_idx = None\n",
    "        label_ids = []\n",
    "        for word_idx in word_ids:                            # Set the special tokens to -100.\n",
    "            if word_idx is None:\n",
    "                label_ids.append(-100)\n",
    "            elif word_idx != previous_word_idx:              # Only label the first token of a given word.\n",
    "                label_ids.append(label[word_idx])\n",
    "\n",
    "        labels.append(label_ids)\n",
    "\n",
    "    tokenized_inputs[\"labels\"] = labels\n",
    "    return tokenized_inputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now tokenize and align the labels over the entire dataset with 🤗 Datasets `map` function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_wnut = wnut.map(tokenize_and_align_labels, batched=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, pad your text and labels, so they are a uniform length:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForTokenClassification\n",
    "data_collator = DataCollatorForTokenClassification(tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fine-tune with the Trainer API"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load your model with the [AutoModelForTokenClassification](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForTokenClassification) class along with the number of expected labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer\n",
    "model = AutoModelForTokenClassification.from_pretrained(\"distilbert-base-uncased\", num_labels=len(label_list))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Gather your training arguments in [TrainingArguments](https://huggingface.co/docs/transformers/master/en/main_classes/trainer#transformers.TrainingArguments):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_args = TrainingArguments(\n",
    "    output_dir='./results',\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=3,\n",
    "    weight_decay=0.01,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Collect your model, training arguments, dataset, data collator, and tokenizer in [Trainer](https://huggingface.co/docs/transformers/master/en/main_classes/trainer#transformers.Trainer):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_wnut[\"train\"],\n",
    "    eval_dataset=tokenized_wnut[\"test\"],\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fine-tune your model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fine-tune with TensorFlow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Batch your examples together and pad your text and labels, so they are a uniform length:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForTokenClassification\n",
    "data_collator = DataCollatorForTokenClassification(tokenizer, return_tensors=\"tf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert your datasets to the `tf.data.Dataset` format with `to_tf_dataset`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf_train_set = tokenized_wnut[\"train\"].to_tf_dataset(\n",
    "    columns=[\"attention_mask\", \"input_ids\", \"labels\"],\n",
    "    shuffle=True,\n",
    "    batch_size=16,\n",
    "    collate_fn=data_collator,\n",
    ")\n",
    "\n",
    "tf_validation_set = tokenized_wnut[\"validation\"].to_tf_dataset(\n",
    "    columns=[\"attention_mask\", \"input_ids\", \"labels\"],\n",
    "    shuffle=False,\n",
    "    batch_size=16,\n",
    "    collate_fn=data_collator,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the model with the [TFAutoModelForTokenClassification](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.TFAutoModelForTokenClassification) class along with the number of expected labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TFAutoModelForTokenClassification\n",
    "model = TFAutoModelForTokenClassification.from_pretrained(\"distilbert-base-uncased\", num_labels=len(label_list))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up an optimizer function, learning rate schedule, and some training hyperparameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import create_optimizer\n",
    "\n",
    "batch_size = 16\n",
    "num_train_epochs = 3\n",
    "num_train_steps = (len(tokenized_datasets[\"train\"]) // batch_size) * num_train_epochs\n",
    "optimizer, lr_schedule = create_optimizer(\n",
    "    init_lr=2e-5,\n",
    "    num_train_steps=num_train_steps,\n",
    "    weight_decay_rate=0.01,\n",
    "    num_warmup_steps=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compile the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "model.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Call `model.fit` to fine-tune your model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.fit(\n",
    "    tf_train_set,\n",
    "    validation_data=tf_validation_set,\n",
    "    epochs=num_train_epochs,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='qa_squad'></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Question Answering with SQuAD"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are many types of question answering (QA) tasks. Extractive QA focuses on identifying the answer from the text\n",
    "given a question. In this example, learn how to fine-tune a model on the [SQuAD](https://huggingface.co/datasets/squad) dataset.\n",
    "\n",
    "<Tip>\n",
    "\n",
    "For a more in-depth example of how to fine-tune a model for question answering, take a look at the corresponding\n",
    "[PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/question_answering.ipynb)\n",
    "or [TensorFlow notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/question_answering-tf.ipynb).\n",
    "\n",
    "</Tip>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load SQuAD dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the SQuAD dataset from the 🤗 Datasets library:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "squad = load_dataset(\"squad\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Take a look at an example from the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "squad[\"train\"][0]\n",
    "{'answers': {'answer_start': [515], 'text': ['Saint Bernadette Soubirous']},\n",
    " 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \"Venite Ad Me Omnes\". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',\n",
    " 'id': '5733be284776f41900661182',\n",
    " 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',\n",
    " 'title': 'University_of_Notre_Dame'\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocess"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the DistilBERT tokenizer with an [AutoTokenizer](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoTokenizer):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are a few things to be aware of when preprocessing text for question answering:\n",
    "\n",
    "1. Some examples in a dataset may have a very long `context` that exceeds the maximum input length of the model. You\n",
    "   can deal with this by truncating the `context` and set `truncation=\"only_second\"`.\n",
    "2. Next, you need to map the start and end positions of the answer to the original context. Set\n",
    "   `return_offset_mapping=True` to handle this.\n",
    "3. With the mapping in hand, you can find the start and end tokens of the answer. Use the `sequence_ids` method to\n",
    "   find which part of the offset corresponds to the question, and which part of the offset corresponds to the context.\n",
    "\n",
    "Assemble everything in a preprocessing function as shown below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_function(examples):\n",
    "    questions = [q.strip() for q in examples[\"question\"]]\n",
    "    inputs = tokenizer(\n",
    "        questions,\n",
    "        examples[\"context\"],\n",
    "        max_length=384,\n",
    "        truncation=\"only_second\",\n",
    "        return_offsets_mapping=True,\n",
    "        padding=\"max_length\",\n",
    "    )\n",
    "\n",
    "    offset_mapping = inputs.pop(\"offset_mapping\")\n",
    "    answers = examples[\"answers\"]\n",
    "    start_positions = []\n",
    "    end_positions = []\n",
    "\n",
    "    for i, offset in enumerate(offset_mapping):\n",
    "        answer = answers[i]\n",
    "        start_char = answer[\"answer_start\"][0]\n",
    "        end_char = answer[\"answer_start\"][0] + len(answer[\"text\"][0])\n",
    "        sequence_ids = inputs.sequence_ids(i)\n",
    "\n",
    "        # Find the start and end of the context\n",
    "        idx = 0\n",
    "        while sequence_ids[idx] != 1:\n",
    "            idx += 1\n",
    "        context_start = idx\n",
    "        while sequence_ids[idx] == 1:\n",
    "            idx += 1\n",
    "        context_end = idx - 1\n",
    "\n",
    "        # If the answer is not fully inside the context, label it (0, 0)\n",
    "        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:\n",
    "            start_positions.append(0)\n",
    "            end_positions.append(0)\n",
    "        else:\n",
    "            # Otherwise it's the start and end token positions\n",
    "            idx = context_start\n",
    "            while idx <= context_end and offset[idx][0] <= start_char:\n",
    "                idx += 1\n",
    "            start_positions.append(idx - 1)\n",
    "\n",
    "            idx = context_end\n",
    "            while idx >= context_start and offset[idx][1] >= end_char:\n",
    "                idx -= 1\n",
    "            end_positions.append(idx + 1)\n",
    "\n",
    "    inputs[\"start_positions\"] = start_positions\n",
    "    inputs[\"end_positions\"] = end_positions\n",
    "    return inputs"
   ]
  },
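  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Apply the preprocessing function to the entire dataset with the 🤗 Datasets `map` function. Setting `batched=True` processes multiple examples at once, and `remove_columns` drops the original columns that the model does not need:"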
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_squad = squad.map(preprocess_function, batched=True, remove_columns=squad[\"train\"].column_names)"
   ]
  },
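  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Batch the processed examples together with the default data collator. No additional padding is required because every example was already padded to `max_length` during preprocessing:"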
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import default_data_collator\n",
    "data_collator = default_data_collator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fine-tune with the Trainer API"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load your model with the [AutoModelForQuestionAnswering](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForQuestionAnswering) class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "model = AutoModelForQuestionAnswering.from_pretrained(\"distilbert-base-uncased\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Gather your training arguments in [TrainingArguments](https://huggingface.co/docs/transformers/master/en/main_classes/trainer#transformers.TrainingArguments):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./results\",\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=3,\n",
    "    weight_decay=0.01,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pass your model, training arguments, training and validation datasets, data collator, and tokenizer to [Trainer](https://huggingface.co/docs/transformers/master/en/main_classes/trainer#transformers.Trainer):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_squad[\"train\"],\n",
    "    eval_dataset=tokenized_squad[\"validation\"],\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fine-tune your model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fine-tune with TensorFlow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Batch the processed examples together with a TensorFlow default data collator:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers.data.data_collator import tf_default_collator\n",
    "data_collator = tf_default_collator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert your datasets to the `tf.data.Dataset` format with the `to_tf_dataset` method, shuffling only the training set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf_train_set = tokenized_squad[\"train\"].to_tf_dataset(\n",
    "    columns=[\"attention_mask\", \"input_ids\", \"start_positions\", \"end_positions\"],\n",
    "    dummy_labels=True,\n",
    "    shuffle=True,\n",
    "    batch_size=16,\n",
    "    collate_fn=data_collator,\n",
    ")\n",
    "\n",
    "tf_validation_set = tokenized_squad[\"validation\"].to_tf_dataset(\n",
    "    columns=[\"attention_mask\", \"input_ids\", \"start_positions\", \"end_positions\"],\n",
    "    dummy_labels=True,\n",
    "    shuffle=False,\n",
    "    batch_size=16,\n",
    "    collate_fn=data_collator,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up an optimizer function, learning rate schedule, and some training hyperparameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import create_optimizer\n",
    "\n",
    "batch_size = 16\n",
    "num_epochs = 2\n",
    "total_train_steps = (len(tokenized_squad[\"train\"]) // batch_size) * num_epochs\n",
    "optimizer, schedule = create_optimizer(\n",
    "    init_lr=2e-5, \n",
    "    num_warmup_steps=0, \n",
    "    num_train_steps=total_train_steps,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load your model with the [TFAutoModelForQuestionAnswering](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.TFAutoModelForQuestionAnswering) class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TFAutoModelForQuestionAnswering\n",
    "model = TFAutoModelForQuestionAnswering(\"distilbert-base-uncased\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compile the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "model.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Call `model.fit` to fine-tune the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.fit(\n",
    "    tf_train_set,\n",
    "    validation_data=tf_validation_set,\n",
    "    epochs=num_train_epochs,\n",
    ")"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
